"""
General utility functions.
"""

import os
import json
from pathlib import Path
from typing import Callable, Tuple
import csv
from typing import List, Optional
from datetime import datetime

from pts.object import BarData, HistoryRequest, Interval


def _get_trader_dir(temp_name: str) -> Tuple[Path, Path]:
    """
    Get path where trader is running in.
    """
    cwd = Path.home()
    temp_path = cwd.joinpath(temp_name)

    # If .vntrader folder exists in current working directory,
    # then use it as trader running path.
    if temp_path.exists():
        return temp_path

    if not temp_path.exists():
        temp_path.mkdir()

    return temp_path


# Module-wide folder (under the user's home directory) holding temp/config files.
TEMP_DIR: Path = _get_trader_dir(temp_name=".pts")


def get_file_path(filename: str) -> Path:
    """
    Build the full path of ``filename`` inside the trader's temp folder (``TEMP_DIR``).
    """
    return TEMP_DIR / filename


def get_icon_path(filepath: str, ico_name: str) -> str:
    """
    Return, as a string, the path of ``ico_name`` inside the ``ico`` folder
    that sits next to ``filepath``.
    """
    ico_dir = Path(filepath).parent / "ico"
    return str(ico_dir / ico_name)


def load_json(filename: str) -> dict:
    """
    Read a json file from the temp folder.

    If the file does not exist yet, an empty json object is written to it
    and an empty dict is returned.
    """
    filepath = get_file_path(filename)

    if not filepath.exists():
        # First use: create the file so later saves/loads find it.
        save_json(filename, {})
        return {}

    with open(filepath, mode="r", encoding="UTF-8") as f:
        return json.load(f)


def save_json(filename: str, data: dict) -> None:
    """
    Save data into a json file in the temp folder (UTF-8, 4-space indent,
    non-ASCII characters kept as-is).

    The data is serialized before the file is opened. A serialization error,
    such as a value that is not JSON-serializable, therefore raises without
    truncating an existing file, instead of leaving it empty or half-written.
    """
    filepath = get_file_path(filename)
    content = json.dumps(
        data,
        indent=4,
        ensure_ascii=False
    )
    with open(filepath, mode="w+", encoding="UTF-8") as f:
        f.write(content)

def virtual(func: Callable) -> Callable:
    """
    Mark ``func`` as "virtual", meaning subclasses may override it.

    This is a pure marker decorator and returns ``func`` unchanged. Base
    classes should decorate every function that subclasses may
    (re)implement with either this decorator or ``@abstractmethod``.
    """
    return func


def import_data_from_csv(req: dict) -> List[BarData]:
    """
    Load the bars (K-line data) for one symbol/interval from a CSV file.

    The file path is ``req['KLineData'] + symbol + '_' + interval.value + '.csv'``.
    This is plain string concatenation, so ``KLineData`` should end with a
    path separator. The CSV needs a header row containing the columns
    ``datetime``, ``open``, ``high``, ``low``, ``close``, ``volume`` and
    ``amount``.

    :param req: dict with keys ``symbol``, ``exchange``, ``interval`` (an
        ``Interval``) and ``KLineData`` (path prefix of the CSV files).
    :return: list of ``BarData`` in file order; empty when the file does not exist.
    """
    symbol = req['symbol']
    exchange = req['exchange']
    interval = req['interval']
    file_path = req['KLineData']

    file_name = file_path + symbol + '_' + interval.value + '.csv'
    if not os.path.exists(file_name):
        return []

    # Daily bars carry only a date; all other intervals carry date and time.
    if interval == Interval.DAILY:
        datetime_format = '%Y-%m-%d'
    else:
        datetime_format = '%Y-%m-%d %H:%M:%S'

    bars: List[BarData] = []

    # NOTE(review): no encoding is given, so the platform default is used;
    # confirm the encoding the CSV files are written with.
    with open(file_name, "rt") as f:
        # Strip NUL characters (which csv rejects on older Python versions),
        # streaming the file instead of buffering every line first.
        lines = (line.replace("\0", "") for line in f)

        for item in csv.DictReader(lines, delimiter=","):
            bar = BarData(
                symbol=symbol,
                exchange=exchange,
                datetime=datetime.strptime(item['datetime'], datetime_format),
                interval=interval,
                volume=float(item['volume']),
                open_price=float(item['open']),
                high_price=float(item['high']),
                low_price=float(item['low']),
                close_price=float(item['close']),
                # NOTE(review): the 'amount' column is stored as open_interest;
                # confirm this mapping is intended.
                open_interest=float(item['amount']),
            )
            bars.append(bar)

    return bars
